import numpy as np
import pandas as pd
import argparse
import random
import pulp
import math
from collections import defaultdict
import time
import multiprocessing as mp
from functools import partial

def x_rule(t, bt, swa_ctr_list, win_ecpm_list, second_ecpm_list, is_click_vec, is_win_list):
    """Return the click outcome of round ``t`` if bid ``bt`` wins, else 0.

    The bid wins when ``swa_ctr_list[t] * bt`` is strictly greater than the
    price to beat: ``win_ecpm_list[t]`` when ``is_win_list[t] == 0``,
    otherwise ``second_ecpm_list[t]``.

    Returns:
        ``is_click_vec[t]`` when the bid wins, 0 otherwise.
    """
    # Pick the price the bid has to exceed for this round.
    if is_win_list[t] == 0:
        price_to_beat = win_ecpm_list[t]
    else:
        price_to_beat = second_ecpm_list[t]

    bid_wins = int(swa_ctr_list[t] * bt > price_to_beat)
    return is_click_vec[t] * bid_wins

def q_rule(t, bt, swa_ctr_list, win_ecpm_list, second_ecpm_list, is_click_vec, is_win_list):
    """Return the payment charged in round ``t`` when bidding ``bt``.

    The bid wins when ``swa_ctr_list[t] * bt`` is strictly greater than the
    price to beat: ``win_ecpm_list[t]`` when ``is_win_list[t] == 0``,
    otherwise ``second_ecpm_list[t]``.

    Returns:
        ``is_click_vec[t] * price / swa_ctr_list[t]`` when the bid wins and
        the round has a click, otherwise ``0.0``.
    """
    if is_win_list[t] == 0:
        price_to_beat = win_ecpm_list[t]
    else:
        price_to_beat = second_ecpm_list[t]

    ctr = swa_ctr_list[t]
    bid_wins = ctr * bt > price_to_beat

    # Nothing is charged without a win and a click. Returning here also keeps
    # a zero CTR from triggering a division by zero (or a NaN payment) for a
    # bid that could not have been charged anyway.
    if not bid_wins or not is_click_vec[t]:
        return 0.0

    return is_click_vec[t] * price_to_beat / ctr

def solve_w_star(v_t_star, x_t_star, q_t_star, rho, tcpa, B_size):
    """Solve the linear program that yields the bid-mixing weights ``w``.

    Maximizes ``sum_i w_i * v_t_star * x_t_star[i]`` subject to
      * ``sum_i w_i == 1`` and ``w_i >= 0``,
      * ``sum_i w_i * q_t_star[i] <= sum_i w_i * v_t_star * x_t_star[i] * tcpa``,
      * ``sum_i w_i * q_t_star[i] <= rho``.

    Args:
        v_t_star: Scalar multiplier applied to every ``x_t_star[i]``.
        x_t_star: Sequence of ``B_size`` per-bid values used in the objective.
        q_t_star: Sequence of ``B_size`` per-bid payments.
        rho: Upper bound on the weighted payment.
        tcpa: Factor applied to the objective terms in the first constraint.
        B_size: Number of candidate bids (number of LP variables).

    Returns:
        A list of ``B_size`` non-negative floats, the optimal weights.

    Raises:
        ValueError: If the solver does not report an optimal solution.
    """
    problem = pulp.LpProblem("w_star_optimization", pulp.LpMaximize)
    w_vars = [pulp.LpVariable(f"w_{i}", lowBound=0, cat=pulp.LpContinuous) for i in range(B_size)]
    problem += pulp.lpSum(w_vars) == 1, "Sum_to_One_Constraint"

    # Objective: weighted value of the candidate bids.
    objective = pulp.lpSum([w_vars[i] * v_t_star * x_t_star[i] for i in range(B_size)])
    problem += objective, "Total_Objective"

    # Weighted payment; shared by both constraints below.
    weighted_payment = pulp.lpSum([w_vars[i] * q_t_star[i] for i in range(B_size)])
    weighted_value_cap = pulp.lpSum([w_vars[i] * v_t_star * x_t_star[i] * tcpa for i in range(B_size)])
    problem += weighted_payment <= weighted_value_cap, "Constraint1"
    problem += weighted_payment <= rho, "Constraint2"

    problem.solve(pulp.PULP_CBC_CMD(msg=0, options=['randomSeed=1']))

    if pulp.LpStatus[problem.status] != "Optimal":
        raise ValueError("Optimization failed, no optimal or feasible solution found.")

    # Solver round-off can report values marginally below the lower bound of
    # 0; clamp them so the weights can be used directly as probabilities.
    return [max(0.0, pulp.value(w_var)) for w_var in w_vars]


def _dual_descent_step(cvr_estimate, lam, mu, budget_left, *, ctr, win_ecpm, second_ecpm,
                       is_win, is_click, true_cvr, gmv, t_roi, alpha, beta, rho):
    """Run one round of a dual-mirror-descent bidder and return its new state.

    Args:
        cvr_estimate: Conversion-rate value this bidder plugs into its bid.
        lam, mu: Current dual variables of the bidder.
        budget_left: Budget remaining before this round.
        ctr, win_ecpm, second_ecpm, is_win: Values of the current row.
        is_click: Simulated click indicator (0 or 1) for the round.
        true_cvr: Value multiplied into the reported ``flag``.
        gmv, t_roi: Factors of the bid formula (``gmv / t_roi``).
        alpha, beta: Step sizes for the ``lam`` and ``mu`` updates.
        rho: Per-round budget (total budget divided by the horizon).

    Returns:
        ``(bid, won, payment, flag, lam, mu, budget_left)`` after the round.
    """
    # Bid only while the remaining budget covers the larger of the two
    # possible charges for this round.
    if budget_left >= max(win_ecpm / ctr, second_ecpm / ctr):
        bid = (1 + lam) / (mu + lam) * gmv / t_roi * cvr_estimate
    else:
        bid = 0

    price_to_beat = win_ecpm if is_win == 0 else second_ecpm
    won = int(ctr * bid > price_to_beat)
    payment = won * is_click * price_to_beat / ctr
    flag = won * is_click * true_cvr

    # Dual updates: multiplicative step for lam, projected step for mu.
    gb = gmv / t_roi * cvr_estimate * won * is_click - payment
    lam = lam * np.exp(-alpha * gb)
    gb_prime = rho - payment
    mu = max(0.0, mu - beta * gb_prime)

    return bid, won, payment, flag, lam, mu, budget_left - payment


def process_sku(sku_group, T, bid_num, new_class_size, seed):
    """Simulate four bidding strategies on ``T`` sampled rows of one SKU.

    The strategies are the "ucb" bidder (weights from ``solve_w_star``) and
    three dual-mirror-descent bidders that differ only in the conversion-rate
    value they bid with: ``swa_cvr + d`` ("cali"), ``average_alpha`` ("true")
    and ``swa_cvr`` ("pred").

    Args:
        sku_group: ``(sku_id, DataFrame)`` pair as produced by iterating a
            pandas ``groupby``. The frame must contain the columns
            ``sku_id``, ``swa_ctr``, ``swa_cvr``, ``swa_gmv``, ``budget``,
            ``win_ecpm``, ``second_ecpm``, ``is_win``, ``class_new``,
            ``average_alpha``, ``max_bid_new``, ``d`` and ``t_ROI``.
        T: Number of rows to sample (without replacement) and simulate.
        bid_num: Number of candidate bids in each class's bid grid.
        new_class_size: Number of classes; ``class_new`` values are used as
            1-based class indices.
        seed: ``random_state`` passed to ``DataFrame.sample``.

    Returns:
        A list of ``T`` comma-separated result lines, one per round.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be at least 1")

    _, df = sku_group

    # Sample the dataset
    df = df.sample(n=T, random_state=seed)

    # Pull the columns that the simulation actually reads.
    swa_ctr_list = df['swa_ctr'].tolist()
    swa_cvr_list = df['swa_cvr'].tolist()
    win_ecpm_list = df['win_ecpm'].tolist()
    second_ecpm_list = df['second_ecpm'].tolist()
    is_win_list = df['is_win'].tolist()
    class_value_list = df['class_new'].tolist()
    average_alpha_list = df['average_alpha'].tolist()
    max_bid_list = df['max_bid_new'].tolist()
    cvr_cali_list = [cvr + d for cvr, d in zip(swa_cvr_list, df['d'].tolist())]

    # NOTE(review): these scalars come from the LAST sampled row and are used
    # for every round; this presumes they are constant within a SKU — confirm.
    sku_id = df['sku_id'].tolist()[-1]
    swa_gmv = df['swa_gmv'].tolist()[-1]
    tROI = df['t_ROI'].tolist()[-1]
    budget_total = df['budget'].tolist()[-1]

    # Every strategy starts with the full budget.
    budget_t_cali = budget_total
    budget_t_true = budget_total
    budget_t_pred = budget_total
    budget_t_ucb = budget_total
    rho = budget_total / T  # per-round budget

    # Define constants in the dual mirror descent algorithms
    alpha = 1 / np.sqrt(T)
    beta = 1 / ((1 + rho ** 2) * np.sqrt(T))

    # Initialize lambda and mu parameters for the dual mirror descent algorithms
    lamba_cali = lamba_true = lamba_pred = 1
    mu_cali = mu_true = mu_pred = 0

    # Number of sampled rows per class (classes are numbered from 1).
    class_size_list = [class_value_list.count(i) for i in range(1, new_class_size + 1)]

    # Per-class running state of the ucb bidder.
    count_class = [0] * new_class_size
    win_count = [0] * new_class_size
    v_hat = [0] * new_class_size
    x_t_hat = [[0] * bid_num for _ in range(new_class_size)]
    q_t_hat = [[0] * bid_num for _ in range(new_class_size)]

    # Bid grids are created once so that every class keeps its own grid;
    # each class ends up with the grid built from its last sampled row.
    B_list = [[0] * bid_num for _ in range(new_class_size)]

    # NOTE(review): the draws below (and further down) use the global numpy
    # RNG, which `seed` does not reset here; under multiprocessing the stream
    # a SKU sees presumably depends on which worker runs it — confirm that
    # this is acceptable for reproducibility.
    is_click_vec = [0] * T
    for t in range(T):
        is_click_vec[t] = int(np.random.rand() < swa_ctr_list[t])
        B_list[class_value_list[t] - 1] = np.linspace(0, max_bid_list[t], bid_num).tolist()

    rule_args = (swa_ctr_list, win_ecpm_list, second_ecpm_list, is_click_vec, is_win_list)
    results = []

    # Loop through each timestep to compute bid strategies and results
    for t in range(T):
        # --------------------------------------- ucb ---------------------------------------
        start_ucb = time.time()
        is_click = is_click_vec[t]
        cls_idx = class_value_list[t] - 1
        bids = B_list[cls_idx]
        a = count_class[cls_idx] + 1  # 1-based visit number of this class

        if a == 1:
            # First visit of the class: bid the top of the grid and
            # initialise the running estimates with this round's outcomes.
            bt = bids[bid_num - 1]
            x_t_hat[cls_idx] = [x_rule(t, bid, *rule_args) for bid in bids]
            q_t = q_rule(t, bt, *rule_args)
            q_t_hat[cls_idx] = [q_rule(t, bid, *rule_args) for bid in bids]

            if q_t > 0:
                v = int(np.random.rand() < average_alpha_list[t])
                win_count[cls_idx] += 1
                v_hat[cls_idx] = v
        else:
            log_term = math.log(2 * bid_num * class_size_list[cls_idx], 2)
            delta = (1 / (2 * a - 2) * log_term) ** 0.5

            # Estimates shifted by the confidence radius: x upwards (capped
            # at 1), q downwards (floored at 0).
            x_star = [min(x + delta, 1) for x in x_t_hat[cls_idx]]
            if win_count[cls_idx] == 0:
                v_star = 1
            else:
                v_star = min(1, v_hat[cls_idx] + (math.log(2 * class_size_list[cls_idx], 2) / win_count[cls_idx]) ** 0.5)
            q_star = [max(0, q - delta) for q in q_t_hat[cls_idx]]

            # Draw the bid from the mixing weights returned by the LP.
            w_star = solve_w_star(v_star, x_star, q_star, rho, swa_gmv / tROI, bid_num)
            result = np.random.choice(bid_num, p=np.array(w_star))
            bt = bids[result]
            q_t = q_rule(t, bt, *rule_args)

            # Update x_t_hat and q_t_hat (running means over the class visits)
            for i, bid in enumerate(bids):
                x_now = x_rule(t, bid, *rule_args)
                q_now = q_rule(t, bid, *rule_args)
                x_t_hat[cls_idx][i] = ((a - 1) * x_t_hat[cls_idx][i] + x_now) / a
                q_t_hat[cls_idx][i] = ((a - 1) * q_t_hat[cls_idx][i] + q_now) / a

            if q_t > 0:
                v = int(np.random.rand() < average_alpha_list[t])
                win_count[cls_idx] += 1
                v_hat[cls_idx] = ((win_count[cls_idx] - 1) * v_hat[cls_idx] + v) / win_count[cls_idx]

        # Win indicator of the chosen bid (independent of the click).
        if is_win_list[t] == 0:
            is_win_exp_ucb = int(swa_ctr_list[t] * bt > win_ecpm_list[t])
        else:
            is_win_exp_ucb = int(swa_ctr_list[t] * bt > second_ecpm_list[t])

        payment_ucb = q_t
        bid_ucb = bt
        budget_t_ucb = budget_t_ucb - payment_ucb
        flag_ucb = is_win_exp_ucb * is_click * average_alpha_list[t]
        count_class[cls_idx] += 1
        time_ucb = time.time() - start_ucb

        # Round data shared by the three dual-mirror-descent bidders.
        round_ctx = dict(
            ctr=swa_ctr_list[t],
            win_ecpm=win_ecpm_list[t],
            second_ecpm=second_ecpm_list[t],
            is_win=is_win_list[t],
            is_click=is_click,
            true_cvr=average_alpha_list[t],
            gmv=swa_gmv,
            t_roi=tROI,
            alpha=alpha,
            beta=beta,
            rho=rho,
        )

        # --------------------------------------- adjust ---------------------------------------
        start_cali = time.time()
        (bid_cali, is_win_exp_cali, payment_cali, flag_cali,
         lamba_cali, mu_cali, budget_t_cali) = _dual_descent_step(
            cvr_cali_list[t], lamba_cali, mu_cali, budget_t_cali, **round_ctx)
        time_cali = time.time() - start_cali

        # --------------------------------------- true ---------------------------------------
        start_true = time.time()
        (bid_true, is_win_exp_true, payment_true, flag_true,
         lamba_true, mu_true, budget_t_true) = _dual_descent_step(
            average_alpha_list[t], lamba_true, mu_true, budget_t_true, **round_ctx)
        time_true = time.time() - start_true

        # --------------------------------------- pred ---------------------------------------
        start_pred = time.time()
        (bid_pred, is_win_exp_pred, payment_pred, flag_pred,
         lamba_pred, mu_pred, budget_t_pred) = _dual_descent_step(
            swa_cvr_list[t], lamba_pred, mu_pred, budget_t_pred, **round_ctx)
        time_pred = time.time() - start_pred

        # get results (the continuation lines' leading spaces are part of the output)
        result_line = f"{sku_id},{t},{swa_gmv}, {tROI}, {is_click},{is_win_exp_ucb},{is_win_exp_pred},{is_win_exp_cali},{is_win_exp_true},{payment_ucb},{bid_ucb},{payment_pred},{bid_pred}, \
            {payment_cali},{bid_cali}, {payment_true},{bid_true},{budget_t_ucb},{budget_t_pred}, {budget_t_cali}, {budget_t_true}, {swa_cvr_list[t]}, \
            {cvr_cali_list[t]} ,{average_alpha_list[t]},{flag_ucb},{flag_pred}, {flag_cali}, {flag_true}, {max_bid_list[t]}, {time_ucb}, {time_pred}, {time_cali}, {time_true}"
        results.append(result_line)

    return results

def main():
    """Parse the CLI arguments, run ``process_sku`` per SKU in parallel and print the result lines."""
    # Setup and argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument('--T', type=int, required=True)
    parser.add_argument('--file_name', type=str, required=True)
    parser.add_argument('--bid_num', type=int, required=True)
    parser.add_argument('--num_processes', type=int, default=32, help='Number of processes to use (default: 32)')

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    new_class_size = 1

    # One (sku_id, DataFrame) pair per SKU, materialized for pool.map
    df = pd.read_csv(args.file_name)
    sku_groups = list(df.groupby('sku_id'))

    # Create partial function with fixed parameters
    process_func = partial(
        process_sku,
        T=args.T,
        bid_num=args.bid_num,
        new_class_size=new_class_size,
        seed=args.seed,
    )

    # The context manager tears the workers down even if a task raises;
    # close/join on the success path lets them exit normally first.
    with mp.Pool(processes=args.num_processes) as pool:
        all_results = pool.map(process_func, sku_groups)
        pool.close()
        pool.join()

    # Print all results
    for sku_results in all_results:
        for result in sku_results:
            print(result)


if __name__ == "__main__":
    main()